/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.core.conditions;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.core.conditions.query.JoinQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.segments.MergeSegments;
import com.baomidou.mybatisplus.core.enums.BaseFuncEnum;
import com.baomidou.mybatisplus.core.metadata.TableFieldInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Assert;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.LambdaUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;

import lombok.Data;

/**
 * Base wrapper for join queries.
 * <p>Centralizes table-join handling: per-entity table aliases, the join part of the FROM clause and the
 * select column list spanning all joined tables.</p>
 * Parts of this class are ported from mybatis-plus-join.
 *
 * @author wanglei
 * @since 2022-03-17
 */
@SuppressWarnings("serial")
public abstract class AbstractJoinWrapper<T, R, Children extends AbstractJoinWrapper<T, R, Children>> extends AbstractWrapper<T, R, Children>
    implements JoinQueryWrapper<T> {

    /**
     * Explicitly selected columns.
     */
    protected List<SelectColumn> selectColumns = new ArrayList<>();

    /**
     * Function-based select fragments; merged into the select list via {@code StringUtils.parseSqlSelect}.
     */
    protected List<SharedString> funSqlSelect = new ArrayList<>();

    /**
     * Columns to drop from {@link #selectColumns} (plain, non-function columns only; see {@link #getSqlSelect()}).
     */
    protected List<SelectColumn> ignoreColumns = new ArrayList<>();

    /**
     * Join part of the FROM clause: the main table's alias followed by every
     * {@code <joinType> <table> <alias> on ...} fragment, in join order.
     */
    protected StringBuilder joinFroms = new StringBuilder();

    /**
     * Entity class to table alias, in registration order (main table first, then joined tables).
     */
    protected Map<Class<?>, String> aliasMap = new LinkedHashMap<>();


    /**
     * Cached select SQL; built lazily by {@link #getSqlSelect()} and reused while non-blank.
     */
    protected SharedString sqlSelect = new SharedString();

    /**
     * Whether this wrapper backs a count query; if so {@link #getSqlSelect()} returns {@code *}.
     */
    protected boolean isCount = false;

    /**
     * Originally documented as "whether other tables were joined"; only subclasses read/write it.
     * NOTE(review): the name suggests it tracks whether the main table's logic-delete condition was appended — confirm.
     */
    protected boolean isAppendMainLogicDelete;


    /**
     * Creates a wrapper whose main table is mapped by {@code mainClass}.
     * <p>Calls the overridable {@link #parseAlias(Class)} from the constructor, so an override runs before
     * subclass fields are initialized.</p>
     *
     * @param mainClass entity class of the main table
     */
    public AbstractJoinWrapper(Class<?> mainClass) {
        String alias = parseAlias(mainClass);
        // explicit put keeps the main alias registered even if an override of parseAlias does not register it
        aliasMap.put(mainClass, alias);
        joinFroms.append(" ").append(alias).append(" ");
    }

    /**
     * Initializes the state required by the base wrapper: parameter sequence, parameter map,
     * merged condition segments and the last/comment/first SQL fragments.
     */
    @Override
    protected void initNeed() {
        paramNameSeq = new AtomicInteger(0);
        paramNameValuePairs = new HashMap<>(16);
        expression = new MergeSegments();
        lastSql = SharedString.emptyString();
        sqlComment = SharedString.emptyString();
        sqlFirst = SharedString.emptyString();
    }

    /**
     * Marks whether this wrapper is used for a count query.
     *
     * @param isCount {@code true} for a count query, {@code false} for a normal query
     */
    public void setIsCount(boolean isCount) {
        this.isCount = isCount;
    }

    /**
     * Returns the table alias for {@code clazz}, registering it in {@link #aliasMap} on first use.
     * The default alias is the simple class name with its first letter lower-cased (e.g. {@code UserPo -> userPo}).
     *
     * @param clazz entity class
     * @return the alias
     */
    public String parseAlias(Class<?> clazz) {
        return aliasMap.computeIfAbsent(clazz, AbstractJoinWrapper::defaultAlias);
    }

    /**
     * Derives the default alias from the simple class name. Lower-casing uses {@link Locale#ENGLISH} so the alias
     * stays a plain ASCII identifier regardless of the JVM default locale (e.g. Turkish would turn "Item" into "ıtem").
     */
    private static String defaultAlias(Class<?> clazz) {
        String className = clazz.getSimpleName();
        if (className.isEmpty()) {
            throw ExceptionUtils.mpe("can not derive a table alias for anonymous Class: \"%s\".", clazz.getName());
        }
        return className.substring(0, 1).toLowerCase(Locale.ENGLISH) + className.substring(1);
    }

    /**
     * Adds a join on {@code joinClass}.
     * <p>When both properties are non-null, the ON condition is produced by {@link #customBuildJoin(Object, Object)};
     * otherwise it is derived from {@link TableField} metadata on the tables already joined.
     * Joining a class that already has an alias (including the main table) is silently ignored.
     * If the ON condition cannot be built, the alias registration is rolled back and the exception is rethrown.</p>
     *
     * @param joinClass     entity class of the table to join
     * @param leftProperty  left-hand property of a custom ON condition, or {@code null}
     * @param rightProperty right-hand property of a custom ON condition, or {@code null}
     * @param joinType      join keyword, e.g. inner join / left join / right join
     */
    protected <LEFT, RIGHT> void pubJoin(Class<?> joinClass, LEFT leftProperty, RIGHT rightProperty, String joinType) {
        // aliases are derived per class, so a table can only be joined once; repeated joins are no-ops
        if (aliasMap.containsKey(joinClass)) {
            return;
        }
        TableInfo tableInfo = TableInfoHelper.getTableInfo(joinClass);
        if (tableInfo == null) {
            throw ExceptionUtils.mpe("cat not get tableInfo from Class: \"%s\".", joinClass.getName());
        }
        String joinAlias = parseAlias(joinClass);
        aliasMap.put(joinClass, joinAlias);

        String onCondition;
        try {
            onCondition = (rightProperty == null || leftProperty == null)
                ? buildRelationOnCondition(joinClass, joinAlias, tableInfo)
                // custom condition: delegated to the concrete wrapper (which may need the alias registered above)
                : customBuildJoin(leftProperty, rightProperty);
        } catch (RuntimeException e) {
            // a failed join must not leave its alias behind: it would leak into the select list
            // and make a retry look like a duplicate join
            aliasMap.remove(joinClass);
            throw e;
        }

        joinFroms.append(" ").append(joinType).append(" ").append(tableInfo.getTableName())
            .append(" ").append(joinAlias).append(" on ").append(onCondition);
        initLogicDelete(joinClass, joinType);
    }

    /**
     * Builds the ON condition from {@link TableField} metadata found on the tables already joined.
     * <p>Example: UserPo may relate to Org through any of
     * <ul>
     *   <li>{@code Long orgId} annotated {@code @TableField(target = Org.class)} — org_id is a column of the user
     *   table and the field type is a java.* or primitive type;</li>
     *   <li>{@code List<Org> orgs} annotated {@code @TableField(target = Org.class, modelFields = "orgId")};</li>
     *   <li>{@code Org org} annotated {@code @TableField(modelFields = "orgId")}.</li>
     * </ul>
     * The first matching field of the first matching table (in join order) is used. {@code modelFields} /
     * {@code targetFields} entries are resolved as entity property names via {@link #getCache(Class, String)}.
     * Empty {@code modelFields} fall back to the annotated field itself; empty {@code targetFields} fall back to the
     * joined table's key property.</p>
     *
     * @param joinClass     entity class being joined (never used as the left side)
     * @param joinAlias     alias of the joined table
     * @param joinTableInfo table metadata of {@code joinClass}
     * @return the condition text placed after {@code on}
     */
    private String buildRelationOnCondition(Class<?> joinClass, String joinAlias, TableInfo joinTableInfo) {
        Class<?> leftClass = null;
        Field leftField = null;
        for (Class<?> candidate : aliasMap.keySet()) {
            // the table being joined cannot supply the left side of its own ON condition
            if (candidate == joinClass) {
                continue;
            }
            leftField = findRelationField(candidate, joinClass);
            if (leftField != null) {
                leftClass = candidate;
                break;
            }
        }
        if (leftField == null) {
            throw ExceptionUtils.mpe("cat not find relation from Class: \"%s\".", joinClass.getName());
        }
        TableField tableField = leftField.getAnnotation(TableField.class);
        String[] leftProperties = tableField.modelFields();
        String[] rightProperties = tableField.targetFields();

        if (isUnset(leftProperties)) {
            // findRelationField only accepts a field without modelFields when it is a scalar java.*/primitive
            // field, i.e. the field itself is the join column
            leftProperties = new String[]{leftField.getName()};
        }
        if (isUnset(rightProperties)) {
            rightProperties = new String[]{joinTableInfo.getKeyProperty()};
        }
        if (leftProperties.length != rightProperties.length) {
            throw ExceptionUtils.mpe("@TableField The left and right lengths are not equal : \"%s\".", joinClass.getName());
        }

        String leftAlias = aliasMap.get(leftClass);
        StringBuilder on = new StringBuilder();
        for (int i = 0; i < leftProperties.length; i++) {
            if (i > 0) {
                on.append(" and ");
            }
            on.append(leftAlias).append('.').append(requireColumn(leftClass, leftProperties[i]))
                .append('=').append(joinAlias).append('.').append(requireColumn(joinClass, rightProperties[i]));
        }
        return on.toString();
    }

    /**
     * Finds the first field of {@code ownerClass} whose {@link TableField} describes a relation to {@code joinClass}.
     *
     * @return the field, or {@code null} if none matches
     */
    private static Field findRelationField(Class<?> ownerClass, Class<?> joinClass) {
        for (Field field : TableInfoHelper.getAllFields(ownerClass)) {
            TableField tableField = field.getAnnotation(TableField.class);
            if (tableField != null && isRelationTo(field, tableField, joinClass)) {
                return field;
            }
        }
        return null;
    }

    /**
     * A field relates to {@code joinClass} when it targets or is typed as {@code joinClass} and either declares
     * {@code modelFields} or is itself a scalar (non-collection java.* or primitive) join column.
     */
    private static boolean isRelationTo(Field field, TableField tableField, Class<?> joinClass) {
        Class<?> type = field.getType();
        boolean pointsToJoinClass = joinClass.equals(tableField.target()) || joinClass.equals(type);
        boolean hasModelFields = !isUnset(tableField.modelFields());
        boolean scalarColumn = isJavaType(type) && !Collection.class.isAssignableFrom(type);
        return pointsToJoinClass && (hasModelFields || scalarColumn);
    }

    /**
     * @return {@code true} when the annotation array is absent, empty or its first entry is empty
     */
    private static boolean isUnset(String[] values) {
        return values == null || values.length == 0 || StringUtils.isEmpty(values[0]);
    }

    /**
     * @return {@code true} for primitives and JDK types (name starting with {@code java.})
     */
    private static boolean isJavaType(Class<?> type) {
        return type.isPrimitive() || type.getName().startsWith("java.");
    }

    /**
     * Resolves the column name of {@code property}, failing with a descriptive error instead of an NPE.
     */
    private static String requireColumn(Class<?> modelClass, String property) {
        ColumnCache cache = property == null ? null : getCache(modelClass, property);
        if (cache == null) {
            throw ExceptionUtils.mpe("can not find column for property \"%s\" in Class: \"%s\".", property, modelClass.getName());
        }
        return cache.getColumn();
    }

    /**
     * Builds the part of a join after {@code on} from explicitly given properties.
     * The default returns an empty string; concrete wrappers override it.
     *
     * @param leftProperty  left-hand property
     * @param rightProperty right-hand property
     * @return the condition text placed after {@code on}
     */
    protected <LEFT, RIGHT> String customBuildJoin(LEFT leftProperty, RIGHT rightProperty) {
        return "";
    }

    /**
     * Initializes the logic-delete condition for a newly joined table; implemented by subclasses.
     *
     * @param joinClass joined entity class
     * @param joinType  join keyword
     */
    protected abstract void initLogicDelete(Class<?> joinClass, String joinType);


    /**
     * Returns the select SQL fragment.
     * <p>For count queries this is {@code *}. Otherwise it is built once and cached in {@link #sqlSelect}:
     * explicit {@link #selectColumns} (minus plain columns listed in {@link #ignoreColumns}) followed by all columns
     * of every joined table that has no explicit selection, then merged with {@link #funSqlSelect}.
     * Ignored columns are not filtered from tables that fall back to all columns.</p>
     *
     * @throws RuntimeException (via {@code ExceptionUtils.mpe}) if a selected column belongs to a table that is not joined
     */
    @Override
    public String getSqlSelect() {
        if (this.isCount) {
            return StringPool.ASTERISK;
        }
        if (StringUtils.isBlank(sqlSelect.getStringValue())) {
            if (CollectionUtils.isNotEmpty(ignoreColumns)) {
                selectColumns.removeIf(c -> c.getFuncEnum() == null && ignoreColumns.stream().anyMatch(i ->
                    i.getClazz() == c.getClazz() && Objects.equals(c.getColumnName(), i.getColumnName())));
            }
            // tables without any explicit column select all their columns; LinkedHashSet keeps join order
            Set<Class<?>> notSetSelectClass = new LinkedHashSet<>(aliasMap.keySet());
            List<String> parts = new ArrayList<>(selectColumns.size() + 1);
            for (SelectColumn column : selectColumns) {
                notSetSelectClass.remove(column.getClazz());
                parts.add(formatSelectColumn(column));
            }
            if (!notSetSelectClass.isEmpty()) {
                parts.add(parseNotAssignTblColumn(notSetSelectClass));
            }
            sqlSelect.setStringValue(String.join(StringPool.COMMA, parts));
            sqlSelect.setStringValue(StringUtils.parseSqlSelect(funSqlSelect, sqlSelect));
        }
        return sqlSelect.getStringValue();
    }

    /**
     * Renders one explicit column as {@code [func(]alias.column[)][ AS columnAlias]}.
     */
    private String formatSelectColumn(SelectColumn column) {
        String tableAlias = aliasMap.get(column.getClazz());
        if (tableAlias == null) {
            throw ExceptionUtils.mpe("can not select column \"%s\" of Class: \"%s\", the table is not joined.",
                column.getColumnName(), column.getClazz() == null ? null : column.getClazz().getName());
        }
        String qualified = tableAlias + StringPool.DOT + column.getColumnName();
        String expression = column.getFuncEnum() == null ? qualified : String.format(column.getFuncEnum().getSql(), qualified);
        return StringUtils.isBlank(column.getAlias()) ? expression : expression + Constants.AS + column.getAlias();
    }

    /**
     * Lists every column of the given tables as {@code alias.column AS alias_column}.
     * <p>The key column (when the table has one) comes first, followed by the mapped fields; tables without any
     * column are skipped.</p>
     *
     * @param classes joined entity classes without explicitly selected columns
     * @return comma-separated columns qualified by table alias
     */
    public String parseNotAssignTblColumn(Collection<Class<?>> classes) {
        List<String> tableSelects = new ArrayList<>(classes.size());
        for (Class<?> clazz : classes) {
            TableInfo tableInfo = TableInfoHelper.getTableInfo(clazz);
            if (tableInfo == null) {
                throw ExceptionUtils.mpe("cat not get tableInfo from Class: \"%s\".", clazz.getName());
            }
            String alias = aliasMap.get(clazz);
            Set<String> columns = new LinkedHashSet<>();
            // tables without a primary key have no key column; adding it would emit "alias.null"
            if (StringUtils.isNotEmpty(tableInfo.getKeyColumn())) {
                columns.add(tableInfo.getKeyColumn());
            }
            for (TableFieldInfo fieldInfo : tableInfo.getFieldList()) {
                columns.add(fieldInfo.getColumn());
            }
            if (columns.isEmpty()) {
                continue;
            }
            tableSelects.add(columns.stream()
                .map(column -> alias + StringPool.DOT + column + Constants.AS + alias + StringPool.UNDERLINE + column)
                .collect(Collectors.joining(StringPool.COMMA)));
        }
        return String.join(StringPool.COMMA, tableSelects);
    }

    /**
     * Looks up the column cache of an entity property.
     * <p>The key is upper-cased with {@link Locale#ENGLISH} so lookups do not depend on the JVM default locale.</p>
     *
     * @param modelClass entity class
     * @param property   entity property name
     * @return the cache entry, or {@code null} if the class has no column map or the property is unknown
     */
    public static ColumnCache getCache(Class<?> modelClass, String property) {
        Map<String, ColumnCache> cacheMap = LambdaUtils.getColumnMap(modelClass);
        return cacheMap == null ? null : cacheMap.get(property.toUpperCase(Locale.ENGLISH));
    }

    /**
     * Adds every mapped (non-key) field of {@code entityClass} accepted by {@code predicate} to the select list.
     *
     * @param entityClass entity class whose table must be known to {@code TableInfoHelper}
     * @param predicate   field filter
     * @return this wrapper
     */
    public Children select(Class<T> entityClass, Predicate<TableFieldInfo> predicate) {
        TableInfo info = TableInfoHelper.getTableInfo(entityClass);
        Assert.notNull(info, "table can not be find");
        selectColumns.addAll(info.getFieldList().stream()
            .filter(predicate)
            .map(field -> SelectColumn.of(entityClass, field.getColumn()))
            .collect(Collectors.toList()));
        return typedThis;
    }

    /**
     * Converts the logic "not deleted" value to the property's type.
     * <p>Integer/int, Long/long and Boolean/boolean properties get a parsed value; other types get the raw string.</p>
     *
     * @param field logic-delete field
     * @return the "not deleted" value to bind
     */
    protected Object parseLogicNotDeleteValue(TableFieldInfo field) {
        Class<?> fieldType = field.getPropertyType();
        String value = field.getLogicNotDeleteValue();
        if (Integer.class.equals(fieldType) || int.class.equals(fieldType)) {
            return Integer.parseInt(value);
        }
        if (Long.class.equals(fieldType) || long.class.equals(fieldType)) {
            return Long.parseLong(value);
        }
        if (Boolean.class.equals(fieldType) || boolean.class.equals(fieldType)) {
            return Boolean.parseBoolean(value);
        }
        return value;
    }


    /**
     * Join wrappers always report {@code true}.
     */
    @Override
    public boolean isJoin() {
        return true;
    }

    /**
     * Returns the join part of the FROM clause (main alias plus all join fragments).
     */
    @Override
    public String getFrom() {
        return this.joinFroms.toString();
    }

    /**
     * One entry of the select list.
     */
    @Data
    public static class SelectColumn {

        /**
         * Entity class the column belongs to.
         */
        private Class<?> clazz;

        /**
         * Database column name.
         */
        private String columnName;

        /**
         * Column alias in the result set.
         */
        private String alias;

        /**
         * SQL function applied to the column, or {@code null}.
         */
        private BaseFuncEnum funcEnum;


        private SelectColumn(Class<?> clazz, String columnName, String alias, BaseFuncEnum funcEnum) {
            this.clazz = clazz;
            this.columnName = columnName;
            // default alias "<entityAlias>_<column>" mirrors parseNotAssignTblColumn when parseAlias is not overridden
            this.alias = alias != null ? alias : defaultAlias(clazz) + StringPool.UNDERLINE + columnName;
            this.funcEnum = funcEnum;
        }

        public static SelectColumn of(Class<?> clazz, String columnName) {
            return new SelectColumn(clazz, columnName, null, null);
        }

        public static SelectColumn of(Class<?> clazz, String columnName, String alias) {
            return new SelectColumn(clazz, columnName, alias, null);
        }

        public static SelectColumn of(Class<?> clazz, String columnName, String alias, BaseFuncEnum funcEnum) {
            return new SelectColumn(clazz, columnName, alias, funcEnum);
        }
    }
}

